{% from 'macros.jinja' import edge_param, edge_param_dict, node_llm_config_to_dspy_lm %}

import time
import asyncio
from typing import Dict, Any, Literal, Optional

import dspy
import langwatch
from langwatch_nlp.studio.dspy import (
    LangWatchWorkflowModule,
    PredictionWithMetadata,
    EvaluationResultWithMetadata,
    PredictionWithEvaluationAndMetadata,
    {# We can probably remove this import once we have LLM Answer Match migrated to langevals #}
    LLMConfig,
    TemplateAdapter,
)

{#
  Template for the generated WorkflowModule class.

  Context variables read below: nodes, node_templates, inputs,
  FIELD_TYPE_TO_DSPY_TYPE, use_kwargs, do_not_trace, handle_errors, workflow
  (default_llm, template_adapter, edges) and debug_level.

  The execution order of the nodes is resolved while the template renders:
  nodes are grouped into layers by their incoming edges and the generated
  forward body is straight-line Python, one statement per layer.
#}
class WorkflowModule(LangWatchWorkflowModule):
    def __init__(self, run_evaluations: bool = False):
        super().__init__()

        {# One attribute per node, skipping nodes of type entry and end. node_templates[node.id][1] is emitted as the expression passed to self.wrapped and node_templates[node.id][2] supplies the keyword arguments of the resulting call. Evaluator nodes, by type or by data.behave_as, additionally receive run=run_evaluations. #}
        {% for node in nodes|selectattr("type", "ne", "entry")|selectattr("type", "ne", "end")|list %}
        self.{{ node.id }} = self.wrapped({{ node_templates[node.id][1] }}, node_id="{{ node.id }}", {% if node.type == "evaluator" or node.data.behave_as == "evaluator" %}run=run_evaluations{% endif %})(
            {% for key, value in node_templates[node.id][2].items() %}
            {# The rm argument of a retriever is emitted verbatim, so it is rendered as code; every other value is emitted through its repr. #}
            {% if node.type == "retriever" and key == "rm" %}
            {{ key }}={{ value }},
            {% else %}
            {{ key }}={{ value.__repr__() }},
            {% endif %}
            {% endfor %}
        )
        {% endfor %}

    {# input_args becomes the parameter list of forward, each input rendered as identifier: type = None; forward_args is the matching identifier=identifier list used when forward delegates to _forward. #}
    {% set input_args = [] %}
    {% set forward_args = [] %}
    {% for input_field in inputs %}
        {% set _ = input_args.append(input_field.identifier + ": " + FIELD_TYPE_TO_DSPY_TYPE[input_field.type.value] + " = None") %}
        {% set _ = forward_args.append(input_field.identifier + "=" + input_field.identifier) %}
    {% endfor %}
    {# With use_kwargs both lists collapse to a single **kwargs entry and entry values are later read with kwargs.get. #}
    {% if use_kwargs %}
    {% set input_args = ["**kwargs"] %}
    {% set forward_args = ["**kwargs"] %}
    {% endif %}

    {# The workflow span decorator is emitted on forward unless do_not_trace is set. #}
    {% if not do_not_trace %}
    @langwatch.span(type="workflow")
    {% endif %}
    {# With handle_errors, forward is a thin wrapper that returns any exception as the error field of the prediction instead of raising, and the workflow body below is emitted as _forward. Without it the body is emitted directly as forward. #}
    {% if handle_errors %}
    def forward(self, {{ input_args | join(", ") }}) -> dspy.Prediction:
        try:
            return self._forward({{ forward_args | join(", ") }})

        except Exception as e:
            return PredictionWithEvaluationAndMetadata(
                error=e,
                cost=self.cost,
                duration=self.duration,
            )

    def _forward(self, {{ input_args | join(", ") }}) -> dspy.Prediction:
    {% else %}
    def forward(self, {{ input_args | join(", ") }}) -> dspy.Prediction:
    {% endif %}
        config = {}
        {# lm is only set when the workflow defines a default LLM; adapter is only set when template_adapter equals "default". #}
        {% if workflow.default_llm %}
        config["lm"] = {{ node_llm_config_to_dspy_lm(workflow.default_llm) }}
        {% endif %}
        {% if workflow.template_adapter == "default" %}
        config["adapter"] = TemplateAdapter()
        {% endif %}
        with dspy.context(**config):
            self.cost = 0
            self.duration = 0

            {# First build dependency graph: node id mapped to the list of source node ids of its incoming edges. Edges whose target is not a known node id are ignored. #}
            {% set dependency_graph = {} %}
            {% for node in nodes %}
                {% set _ = dependency_graph.update({node.id: []}) %}
            {% endfor %}
            {% for edge in workflow.edges %}
                {% if edge.target in dependency_graph %}
                    {% set _ = dependency_graph[edge.target].append(edge.source) %}
                {% endif %}
            {% endfor %}

            {# Now generate execution layers. The literal id 'entry' is seeded as already executed, so edges coming from a source with that id never block a node. #}
            {% set executed = {'nodes': ['entry']} %}
            {% set layers = [] %}
            {# NOTE(review): previous_executed_count is never read anywhere in the visible template; it looks removable, confirm before deleting. #}
            {% set previous_executed_count = {'value': 0} %}

            {% if debug_level > 1 %}
            # Dependency graph: {{ dependency_graph }}
            {% endif %}

            {# Build layers until all nodes are executed. The loop makes one pass per node in nodes, which bounds the number of layers; passes that find nothing new add no layer. #}
            {% set remaining_nodes = nodes|selectattr("type", "ne", "entry")|selectattr("type", "ne", "end")|list %}
            {% for i in range(nodes|length) %}
                {% set current_layer = [] %}

                {% for node in remaining_nodes %}
                    {# A namespace is used because a plain set inside the inner loop would not be visible after that loop ends. #}
                    {% set ns = namespace(can_execute=true) %}
                    {% for dep in dependency_graph[node.id] %}
                        {% if dep not in executed.nodes %}
                            {% set ns.can_execute = false %}
                        {% endif %}
                    {% endfor %}
                    {% if debug_level > 1 %}
            # {{ i }}: Node: {{ node.id }}, can execute: {{ ns.can_execute }}, executed: {{ executed.nodes }}, dependencies: {{ dependency_graph[node.id] }}
                    {% endif %}
                    {% if ns.can_execute and node.id not in executed.nodes %}
                        {% set _ = current_layer.append(node) %}
                    {% endif %}
                {% endfor %}
                {# Nodes are marked executed only after the whole pass, so two nodes placed in the same layer never depend on each other. #}
                {% if current_layer %}
                    {% set _ = layers.append(current_layer) %}
                    {% for node in current_layer %}
                        {% set _ = executed.nodes.append(node.id) %}
                    {% endfor %}
                {% endif %}
            {% endfor %}
            {# NOTE(review): a node whose dependencies are never all executed, for example in a cycle or with an edge from an unknown source, ends up in no layer, yet the return statement below still references non-evaluator nodes by name; presumably callers validate the graph first, confirm. #}

            {% for layer in layers if debug_level > 0 %}
            # Layer {{ loop.index }}: {{ layer|map(attribute="id")|list }}
            {% endfor %}

            {# Generate code for each layer. A layer with several nodes becomes one self.run_in_parallel call taking one (module, params dict) tuple per node, with the results unpacked into variables named after the node ids; a layer with a single node becomes a direct call. In both cases the arguments come from the edges that target the node. #}
            {% for layer in layers %}
                {% if layer|length > 1 %}
            {{ layer|map(attribute="id")|list|join(", ") }} = self.run_in_parallel(
                    {% for node in layer %}
                (self.{{ node.id }}, {
                        {% for edge in workflow.edges %}
                            {% if edge.target == node.id %}
                                {% set source_parts = edge.sourceHandle.split(".") %}
                                {% set target_parts = edge.targetHandle.split(".") %}
                                {{ edge_param_dict(edge, source_parts, target_parts, use_kwargs) }}
                            {% endif %}
                        {% endfor %}
                }),
                    {% endfor %}
            )

                {% else %}
            {{ layer[0].id }} = self.{{ layer[0].id }}(
                    {% for edge in workflow.edges %}
                        {% if edge.target == layer[0].id %}
                            {% set source_parts = edge.sourceHandle.split(".") %}
                            {% set target_parts = edge.targetHandle.split(".") %}
                            {{ edge_param(edge, source_parts, target_parts, use_kwargs) }}
                        {% endif %}
                    {% endfor %}
            )
                {% endif %}
            {% endfor %}

            {# Finally execute end node. If several nodes have type end, the last one in nodes is used. #}
            {% set ns = namespace(end_node=None) %}
            {% for node in nodes %}
                {% if node.type == "end" %}
                    {% set ns.end_node = node %}
                {% endif %}
            {% endfor %}

            {# The prediction carries one keyword per node that is not an entry, end or evaluator node, an end dict when an end node exists, and an evaluations dict keyed by evaluator node id. #}
            return PredictionWithEvaluationAndMetadata(
                {% for node in nodes|selectattr("type", "ne", "entry")|selectattr("type", "ne", "end")|selectattr("type", "ne", "evaluator")|selectattr("data.behave_as", "ne", "evaluator")|list %}
                {{ node.id }}={{ node.id }},
                {% endfor %}
                {% if ns.end_node %}
                end={
                    {% for edge in workflow.edges %}
                        {% if edge.target == ns.end_node.id %}
                            {% set source_parts = edge.sourceHandle.split(".") %}
                            {% set target_parts = edge.targetHandle.split(".") %}
                            {# The key is the second dot-separated part of targetHandle, falling back to the first part when the second is empty or missing. #}
                            {% set target_handle = target_parts[1] if target_parts[1] else target_parts[0] %}
                            {# Values wired straight from the entry come from the forward arguments; any other value is read as an attribute of the source node's result variable. #}
                            {% if edge.source == "entry" %}
                    "{{ target_handle }}": {% if use_kwargs %}kwargs.get("{{ source_parts[1] }}"){% else %}{{ source_parts[1] }}{% endif %},
                            {% else %}
                    "{{ target_handle }}": {{ edge.source }}.{{ source_parts[1] }},
                            {% endif %}
                        {% endif %}
                    {% endfor %}
                },
                {% endif %}
                evaluations={
                    {% for node in nodes if node.type == "evaluator" or (node.data.behave_as == "evaluator" and node.type != "end") %}
                    "{{ node.id }}": {{ node.id }},
                    {% endfor %}
                },
                cost=self.cost,
                duration=self.duration,
            )

{# Emit element 0 of every node_templates entry verbatim at module level, in mapping order. Presumably this is the source that defines the component referenced by element 1 of the same entry; confirm against the code that builds node_templates. #}
{% for node_template in node_templates.values() %}
{{ node_template[0] }}
{% endfor %}